import numpy as np

def param_groups_UniDense(model, weight_decay, base_lr, AWL=None):
    """Build optimizer parameter groups for UniDense training.

    Trainable parameters of ``model`` are bucketed by learning-rate tier and
    weight-decay flag into groups named ``"layer_<tier>_<decay|no_decay>"``:

    * tier 2 (``lr_scale`` 1.0): names containing ``conv_layer_task_specific``,
      ``unet`` or ``aggregator``;
    * tier 1 (``lr_scale`` 0.0, i.e. a zero learning rate): everything else.

    A parameter gets ``weight_decay`` 0 when it is 1-D, its name ends with
    ``.bias``, its name is exactly ``absolute_pos_embed``, or its name contains
    one of ``no_decay_names``. Otherwise it gets ``weight_decay``.

    Args:
        model: object exposing ``named_parameters()`` yielding ``(name, param)``
            pairs whose params have ``.shape`` and ``.requires_grad``
            (e.g. a ``torch.nn.Module``).
        weight_decay: decay used for the "decay" groups and the AWL group.
        base_lr: base learning rate; each layer group's ``"lr"`` is
            ``lr_scale * base_lr``.
        AWL: optional module (presumably a loss-weighting module; confirm with
            the caller) whose parameters all go into a single ``'AWL'`` group
            with ``lr_scale`` 1.0 and no explicit ``"lr"`` key.

    Returns:
        List of group dicts, usable as the ``params`` argument of a
        ``torch.optim`` optimizer. The AWL group comes first when given, then
        groups in first-seen order.
    """
    parameter_groups = {}

    if AWL is not None:
        parameter_groups['AWL'] = {
            "lr_scale": 1.0,
            "weight_decay": weight_decay,
            "params": [p for _, p in AWL.named_parameters()],
        }

    no_decay_names = ['relative_position_bias_table', 'rpe_mlp', 'logit_scale']
    # Exact names excluded from weight decay. This must be a real tuple:
    # ``name in ('absolute_pos_embed')`` is a substring test against a plain
    # string and wrongly matched names such as 'pos' or 'embed'.
    no_decay_exact = ('absolute_pos_embed',)
    # lr multiplier per tier; tier 0 (fine-tune) is not assigned by this variant.
    lr_scales = {0: 0.01, 1: 0.0, 2: 1.0}
    print("Build LDMOptimizerConstructor")

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        # NOTE(review): 0-d parameters are not caught by the 1-D check and
        # receive weight decay unless their name is in no_decay_names; confirm
        # this is intended.
        skip_decay = (
            len(param.shape) == 1
            or name.endswith(".bias")
            or name in no_decay_exact
            or any(nd_name in name for nd_name in no_decay_names)
        )
        decay_tag = "no_decay" if skip_decay else "decay"
        this_weight_decay = 0. if skip_decay else weight_decay

        if ('conv_layer_task_specific' in name or 'unet' in name
                or 'aggregator' in name):
            layer_id = 2
        else:
            layer_id = 1
        group_name = "layer_%d_%s" % (layer_id, decay_tag)

        if group_name not in parameter_groups:
            scale = lr_scales[layer_id]
            parameter_groups[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "param_names": [],
                "lr_scale": scale,
                "group_name": group_name,
                "lr": scale * base_lr,
            }

        parameter_groups[group_name]["params"].append(param)
        parameter_groups[group_name]["param_names"].append(name)

    return list(parameter_groups.values())


def param_groups_UniDense_meta(model, weight_decay, base_lr, AWL=None):
    """Build optimizer parameter groups for the UniDense "meta" stage.

    Trainable parameters of ``model`` are bucketed by learning-rate tier and
    weight-decay flag into groups named ``"layer_<tier>_<decay|no_decay>"``:

    * tier 2 (``lr_scale`` 1.0): names containing ``f_gate``,
      ``conv_layer_meta`` or ``task_extractor``;
    * tier 1 (``lr_scale`` 0.0, i.e. a zero learning rate): everything else.

    A parameter gets ``weight_decay`` 0 when it is 1-D, its name ends with
    ``.bias``, its name is exactly ``absolute_pos_embed``, or its name contains
    one of ``no_decay_names``. Otherwise it gets ``weight_decay``.

    Args:
        model: object exposing ``named_parameters()`` yielding ``(name, param)``
            pairs whose params have ``.shape`` and ``.requires_grad``
            (e.g. a ``torch.nn.Module``).
        weight_decay: decay used for the "decay" groups and the AWL group.
        base_lr: base learning rate; each layer group's ``"lr"`` is
            ``lr_scale * base_lr``.
        AWL: optional module (presumably a loss-weighting module; confirm with
            the caller) whose parameters all go into a single ``'AWL'`` group
            with ``lr_scale`` 1.0 and no explicit ``"lr"`` key.

    Returns:
        List of group dicts, usable as the ``params`` argument of a
        ``torch.optim`` optimizer. The AWL group comes first when given, then
        groups in first-seen order.
    """
    parameter_groups = {}

    if AWL is not None:
        parameter_groups['AWL'] = {
            "lr_scale": 1.0,
            "weight_decay": weight_decay,
            "params": [p for _, p in AWL.named_parameters()],
        }

    no_decay_names = ['relative_position_bias_table', 'rpe_mlp', 'logit_scale']
    # Exact names excluded from weight decay. This must be a real tuple:
    # ``name in ('absolute_pos_embed')`` is a substring test against a plain
    # string and wrongly matched names such as 'pos' or 'embed'.
    no_decay_exact = ('absolute_pos_embed',)
    # lr multiplier per tier; tier 0 (fine-tune) is not assigned by this variant.
    lr_scales = {0: 0.01, 1: 0.0, 2: 1.0}
    print("Build LDMOptimizerConstructor")

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        # NOTE(review): 0-d parameters are not caught by the 1-D check and
        # receive weight decay unless their name is in no_decay_names; confirm
        # this is intended.
        skip_decay = (
            len(param.shape) == 1
            or name.endswith(".bias")
            or name in no_decay_exact
            or any(nd_name in name for nd_name in no_decay_names)
        )
        decay_tag = "no_decay" if skip_decay else "decay"
        this_weight_decay = 0. if skip_decay else weight_decay

        if ('f_gate' in name or 'conv_layer_meta' in name
                or 'task_extractor' in name):
            layer_id = 2
        else:
            layer_id = 1
        group_name = "layer_%d_%s" % (layer_id, decay_tag)

        if group_name not in parameter_groups:
            scale = lr_scales[layer_id]
            parameter_groups[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "param_names": [],
                "lr_scale": scale,
                "group_name": group_name,
                "lr": scale * base_lr,
            }

        parameter_groups[group_name]["params"].append(param)
        parameter_groups[group_name]["param_names"].append(name)

    return list(parameter_groups.values())


def param_groups_UniDense_ft(model, weight_decay, base_lr, AWL=None):
    """Build optimizer parameter groups for UniDense fine-tuning.

    Trainable parameters of ``model`` are bucketed by learning-rate tier and
    weight-decay flag into groups named ``"layer_<tier>_<decay|no_decay>"``:

    * tier 2 (``lr_scale`` 1.0): names containing ``f_gate`` or
      ``conv_layer_task_specific``;
    * tier 0 (``lr_scale`` 0.01): other names containing ``unet``;
    * tier 1 (``lr_scale`` 0.0, i.e. a zero learning rate): everything else.

    A parameter gets ``weight_decay`` 0 when it is 1-D, its name ends with
    ``.bias``, its name is exactly ``absolute_pos_embed``, or its name contains
    one of ``no_decay_names``. Otherwise it gets ``weight_decay``.

    Args:
        model: object exposing ``named_parameters()`` yielding ``(name, param)``
            pairs whose params have ``.shape`` and ``.requires_grad``
            (e.g. a ``torch.nn.Module``).
        weight_decay: decay used for the "decay" groups and the AWL group.
        base_lr: base learning rate; each layer group's ``"lr"`` is
            ``lr_scale * base_lr``.
        AWL: optional module (presumably a loss-weighting module; confirm with
            the caller) whose parameters all go into a single ``'AWL'`` group
            with ``lr_scale`` 1.0 and no explicit ``"lr"`` key.

    Returns:
        List of group dicts, usable as the ``params`` argument of a
        ``torch.optim`` optimizer. The AWL group comes first when given, then
        groups in first-seen order.
    """
    parameter_groups = {}

    if AWL is not None:
        parameter_groups['AWL'] = {
            "lr_scale": 1.0,
            "weight_decay": weight_decay,
            "params": [p for _, p in AWL.named_parameters()],
        }

    no_decay_names = ['relative_position_bias_table', 'rpe_mlp', 'logit_scale']
    # Exact names excluded from weight decay. This must be a real tuple:
    # ``name in ('absolute_pos_embed')`` is a substring test against a plain
    # string and wrongly matched names such as 'pos' or 'embed'.
    no_decay_exact = ('absolute_pos_embed',)
    # lr multiplier per tier: 0 = fine-tune, 1 = fixed (zero lr), 2 = full lr.
    lr_scales = {0: 0.01, 1: 0.0, 2: 1.0}
    print("Build LDMOptimizerConstructor")

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen weights

        # NOTE(review): 0-d parameters are not caught by the 1-D check and
        # receive weight decay unless their name is in no_decay_names; confirm
        # this is intended.
        skip_decay = (
            len(param.shape) == 1
            or name.endswith(".bias")
            or name in no_decay_exact
            or any(nd_name in name for nd_name in no_decay_names)
        )
        decay_tag = "no_decay" if skip_decay else "decay"
        this_weight_decay = 0. if skip_decay else weight_decay

        # Order matters: tier-2 names win over 'unet' when both substrings occur.
        if 'f_gate' in name or 'conv_layer_task_specific' in name:
            layer_id = 2
        elif 'unet' in name:
            layer_id = 0
        else:
            layer_id = 1
        group_name = "layer_%d_%s" % (layer_id, decay_tag)

        if group_name not in parameter_groups:
            scale = lr_scales[layer_id]
            parameter_groups[group_name] = {
                "weight_decay": this_weight_decay,
                "params": [],
                "param_names": [],
                "lr_scale": scale,
                "group_name": group_name,
                "lr": scale * base_lr,
            }

        parameter_groups[group_name]["params"].append(param)
        parameter_groups[group_name]["param_names"].append(name)

    return list(parameter_groups.values())